49d31f8
@@ -21,6 +21,7 @@
 import static org.apache.commons.lang.StringUtils.join;
 import static org.apache.commons.lang.StringUtils.repeat;
 
+import com.google.common.collect.Lists;
 import java.sql.Connection;
 import java.sql.SQLException;
 import java.text.ParseException;
@@ -31,15 +32,11 @@
 import java.util.List;
 import java.util.Map;
 import java.util.TreeMap;
-
 import javax.jdo.PersistenceManager;
 import javax.jdo.Query;
 import javax.jdo.Transaction;
 import javax.jdo.datastore.JDOConnection;
-
 import org.apache.commons.lang.StringUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars;
@@ -72,8 +69,8 @@
 import org.apache.hadoop.hive.serde.serdeConstants;
 import org.apache.hive.common.util.BloomFilter;
 import org.datanucleus.store.rdbms.query.ForwardQueryResult;
-
-import com.google.common.collect.Lists;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * This class contains the optimizations for MetaStore that rely on direct SQL access to
@@ -351,14 +348,18 @@
public Database getDatabase(String dbName) throws MetaException{
    * @param partNames Partition names to get.
    * @return List of partitions.
    */
-  public List<Partition> getPartitionsViaSqlFilter(
-      String dbName, String tblName, List<String> partNames) throws MetaException {
+  public List<Partition> getPartitionsViaSqlFilter(final String dbName, final String tblName,
+      List<String> partNames) throws MetaException {
     if (partNames.isEmpty()) {
       return new ArrayList<Partition>();
     }
-    return getPartitionsViaSqlFilterInternal(dbName, tblName, null,
-        "\"PARTITIONS\".\"PART_NAME\" in (" + makeParams(partNames.size()) + ")",
-        partNames, new ArrayList<String>(), null);
+    return runBatched(partNames, new Batchable<String, Partition>() {
+      public List<Partition> run(List<String> input) throws MetaException {
+        String filter = "\"PARTITIONS\".\"PART_NAME\" in (" + makeParams(input.size()) + ")";
+        return getPartitionsViaSqlFilterInternal(dbName, tblName, null, filter, input,
+            new ArrayList<String>(), null);
+      }
+    });
   }
 
   /**
@@ -450,11 +451,10 @@
private boolean isViewTable(String dbName, String tblName) throws MetaException
    * @return List of partition objects.
    */
   private List<Partition> getPartitionsViaSqlFilterInternal(String dbName, String tblName,
-      Boolean isView, String sqlFilter, List<? extends Object> paramsForFilter,
+      final Boolean isView, String sqlFilter, List<? extends Object> paramsForFilter,
       List<String> joinsForFilter, Integer max) throws MetaException {
     boolean doTrace = LOG.isDebugEnabled();
-    dbName = dbName.toLowerCase();
-    tblName = tblName.toLowerCase();
+    final String dbNameLcase = dbName.toLowerCase(), tblNameLcase = tblName.toLowerCase();
     // We have to be mindful of order during filtering if we are not returning all partitions.
     String orderForFilter = (max != null) ? " order by \"PART_NAME\" asc" : "";
 
@@ -475,8 +475,8 @@
private boolean isViewTable(String dbName, String tblName) throws MetaException
       + join(joinsForFilter, ' ')
       + (StringUtils.isBlank(sqlFilter) ? "" : (" where " + sqlFilter)) + orderForFilter;
     Object[] params = new Object[paramsForFilter.size() + 2];
-    params[0] = tblName;
-    params[1] = dbName;
+    params[0] = tblNameLcase;
+    params[1] = dbNameLcase;
     for (int i = 0; i < paramsForFilter.size(); ++i) {
       params[i + 2] = paramsForFilter.get(i);
     }
@@ -493,23 +493,18 @@
private boolean isViewTable(String dbName, String tblName) throws MetaException
       return new ArrayList<Partition>(); // no partitions, bail early.
     }
 
-    // Get full objects. For Oracle, do it in batches.
-    List<Partition> result = null;
-    if (batchSize != NO_BATCHING && batchSize < sqlResult.size()) {
-      result = new ArrayList<Partition>(sqlResult.size());
-      while (result.size() < sqlResult.size()) {
-        int toIndex = Math.min(result.size() + batchSize, sqlResult.size());
-        List<Object> batchedSqlResult = sqlResult.subList(result.size(), toIndex);
-        result.addAll(getPartitionsFromPartitionIds(dbName, tblName, isView, batchedSqlResult));
+    // Get full objects. For Oracle/etc. do it in batches.
+    List<Partition> result = runBatched(sqlResult, new Batchable<Object, Partition>() {
+      public List<Partition> run(List<Object> input) throws MetaException {
+        return getPartitionsFromPartitionIds(dbNameLcase, tblNameLcase, isView, input);
       }
-    } else {
-      result = getPartitionsFromPartitionIds(dbName, tblName, isView, sqlResult);
-    }
+    });
 
     query.closeAll();
     return result;
   }
 
+  /** Should be called with the list short enough to not trip up Oracle/etc. */
   private List<Partition> getPartitionsFromPartitionIds(String dbName, String tblName,
       Boolean isView, List<Object> partIdList) throws MetaException {
     boolean doTrace = LOG.isDebugEnabled();
@@ -846,7 +841,6 @@
private int getNumPartitionsViaSqlFilterInternal(String dbName, String tblName,
 
     long start = doTrace ? System.nanoTime() : 0;
     Query query = pm.newQuery("javax.jdo.query.SQL", queryText);
-    @SuppressWarnings("unchecked")
     int sqlResult = extractSqlInt(query.executeWithArray(params));
     long queryTime = doTrace ? System.nanoTime() : 0;
     timingTrace(doTrace, queryText, start, queryTime);
@@ -1079,10 +1073,10 @@
public void visit(LeafNode node) throws MetaException {
         return;
       }
 
-      // TODO: if Filter.g does date parsing for quoted strings, we'd need to verify there's no
-      //       type mismatch when string col is filtered by a string that looks like date.
+      // if Filter.g does date parsing for quoted strings, we'd need to verify there's no
+      // type mismatch when string col is filtered by a string that looks like date.
       if (colType == FilterType.Date && valType == FilterType.String) {
-        // TODO: Filter.g cannot parse a quoted date; try to parse date here too.
+        // Filter.g cannot parse a quoted date; try to parse date here too.
         try {
           nodeValue = new java.sql.Date(
               HiveMetaStore.PARTITION_DATE_FORMAT.get().parse((String)nodeValue).getTime());
@@ -1154,35 +1148,40 @@
public void visit(LeafNode node) throws MetaException {
     }
   }
 
-  public ColumnStatistics getTableStats(
-      String dbName, String tableName, List<String> colNames) throws MetaException {
+  public ColumnStatistics getTableStats(final String dbName, final String tableName,
+      List<String> colNames) throws MetaException {
     if (colNames.isEmpty()) {
       return null;
     }
-    boolean doTrace = LOG.isDebugEnabled();
-    long start = doTrace ? System.nanoTime() : 0;
-    String queryText = "select " + STATS_COLLIST + " from \"TAB_COL_STATS\" "
-      + " where \"DB_NAME\" = ? and \"TABLE_NAME\" = ? and \"COLUMN_NAME\" in ("
-      + makeParams(colNames.size()) + ")";
-    Query query = pm.newQuery("javax.jdo.query.SQL", queryText);
-    Object[] params = new Object[colNames.size() + 2];
-    params[0] = dbName;
-    params[1] = tableName;
-    for (int i = 0; i < colNames.size(); ++i) {
-      params[i + 2] = colNames.get(i);
-    }
-    Object qResult = executeWithArray(query, params, queryText);
-    long queryTime = doTrace ? System.nanoTime() : 0;
-    if (qResult == null) {
-      query.closeAll();
-      return null;
-    }
-    List<Object[]> list = ensureList(qResult);
+    final boolean doTrace = LOG.isDebugEnabled();
+    final String queryText0 = "select " + STATS_COLLIST + " from \"TAB_COL_STATS\" "
+          + " where \"DB_NAME\" = ? and \"TABLE_NAME\" = ? and \"COLUMN_NAME\" in (";
+    Batchable<String, Object[]> b = new Batchable<String, Object[]>() {
+      public List<Object[]> run(List<String> input) throws MetaException {
+        String queryText = queryText0 + makeParams(input.size()) + ")";
+        Object[] params = new Object[input.size() + 2];
+        params[0] = dbName;
+        params[1] = tableName;
+        for (int i = 0; i < input.size(); ++i) {
+          params[i + 2] = input.get(i);
+        }
+        long start = doTrace ? System.nanoTime() : 0;
+        Query query = pm.newQuery("javax.jdo.query.SQL", queryText);
+        Object qResult = executeWithArray(query, params, queryText);
+        timingTrace(doTrace, queryText0 + "...)", start, (doTrace ? System.nanoTime() : 0));
+        if (qResult == null) {
+          query.closeAll();
+          return null;
+        }
+        addQueryAfterUse(query);
+        return ensureList(qResult);
+      }
+    };
+    List<Object[]> list = runBatched(colNames, b);
     if (list.isEmpty()) return null;
     ColumnStatisticsDesc csd = new ColumnStatisticsDesc(true, dbName, tableName);
     ColumnStatistics result = makeColumnStats(list, csd, 0);
-    timingTrace(doTrace, queryText, start, queryTime);
-    query.closeAll();
+    b.closeAllQueries();
     return result;
   }
 
@@ -1191,7 +1190,7 @@
public AggrStats aggrColStatsForPartitions(String dbName, String tableName,
       throws MetaException {
     if (colNames.isEmpty() || partNames.isEmpty()) {
       LOG.debug("Columns is empty or partNames is empty : Short-circuiting stats eval");
-      return new AggrStats(new ArrayList<ColumnStatisticsObj>(),0); // Nothing to aggregate
+      return new AggrStats(new ArrayList<ColumnStatisticsObj>(), 0); // Nothing to aggregate
     }
     long partsFound = partsFoundForPartitions(dbName, tableName, partNames, colNames);
     List<ColumnStatisticsObj> colStatsList;
@@ -1203,9 +1202,8 @@
public AggrStats aggrColStatsForPartitions(String dbName, String tableName,
       float fpp = aggrStatsCache.getFalsePositiveProbability();
       int partitionsRequested = partNames.size();
       if (partitionsRequested > maxPartsPerCacheNode) {
-        colStatsList =
-            columnStatisticsObjForPartitions(dbName, tableName, partNames, colNames, partsFound,
-                useDensityFunctionForNDVEstimation);
+        colStatsList = columnStatisticsObjForPartitions(dbName, tableName, partNames, colNames,
+            partsFound, useDensityFunctionForNDVEstimation);
       } else {
         colStatsList = new ArrayList<ColumnStatisticsObj>();
         // Bloom filter for the new node that we will eventually add to the cache
@@ -1251,35 +1249,69 @@
private BloomFilter createPartsBloomFilter(int maxPartsPerCacheNode, float fpp,
     return bloomFilter;
   }
 
-  private long partsFoundForPartitions(String dbName, String tableName,
-      List<String> partNames, List<String> colNames) throws MetaException {
+  private long partsFoundForPartitions(final String dbName, final String tableName,
+      final List<String> partNames, List<String> colNames) throws MetaException {
     assert !colNames.isEmpty() && !partNames.isEmpty();
-    long partsFound = 0;
-    boolean doTrace = LOG.isDebugEnabled();
-    String queryText = "select count(\"COLUMN_NAME\") from \"PART_COL_STATS\""
+    final boolean doTrace = LOG.isDebugEnabled();
+    final String queryText0 = "select count(\"COLUMN_NAME\") from \"PART_COL_STATS\""
         + " where \"DB_NAME\" = ? and \"TABLE_NAME\" = ? "
-        + " and \"COLUMN_NAME\" in (" + makeParams(colNames.size()) + ")"
-        + " and \"PARTITION_NAME\" in (" + makeParams(partNames.size()) + ")"
+        + " and \"COLUMN_NAME\" in (%1$s) and \"PARTITION_NAME\" in (%2$s)"
         + " group by \"PARTITION_NAME\"";
-    long start = doTrace ? System.nanoTime() : 0;
-    Query query = pm.newQuery("javax.jdo.query.SQL", queryText);
-    Object qResult = executeWithArray(query, prepareParams(
-        dbName, tableName, partNames, colNames), queryText);
-    long end = doTrace ? System.nanoTime() : 0;
-    timingTrace(doTrace, queryText, start, end);
-    ForwardQueryResult fqr = (ForwardQueryResult) qResult;
-    Iterator<?> iter = fqr.iterator();
-    while (iter.hasNext()) {
-      if (extractSqlLong(iter.next()) == colNames.size()) {
-        partsFound++;
+    List<Long> allCounts = runBatched(colNames, new Batchable<String, Long>() {
+      public List<Long> run(final List<String> inputColName) throws MetaException {
+        return runBatched(partNames, new Batchable<String, Long>() {
+          public List<Long> run(List<String> inputPartNames) throws MetaException {
+            long partsFound = 0;
+            String queryText = String.format(queryText0,
+                makeParams(inputColName.size()), makeParams(inputPartNames.size()));
+            long start = doTrace ? System.nanoTime() : 0;
+            Query query = pm.newQuery("javax.jdo.query.SQL", queryText);
+            try {
+              Object qResult = executeWithArray(query, prepareParams(
+                  dbName, tableName, inputPartNames, inputColName), queryText);
+              long end = doTrace ? System.nanoTime() : 0;
+              timingTrace(doTrace, queryText, start, end);
+              ForwardQueryResult fqr = (ForwardQueryResult) qResult;
+              Iterator<?> iter = fqr.iterator();
+              while (iter.hasNext()) {
+                if (extractSqlLong(iter.next()) == inputColName.size()) {
+                  partsFound++;
+                }
+              }
+              return Lists.<Long>newArrayList(partsFound);
+            } finally {
+              query.closeAll();
+            }
+          }
+        });
       }
+    });
+    long partsFound = 0;
+    for (Long val : allCounts) {
+      partsFound += val;
     }
-    query.closeAll();
     return partsFound;
   }
 
-  private List<ColumnStatisticsObj> columnStatisticsObjForPartitions(String dbName,
-      String tableName, List<String> partNames, List<String> colNames, long partsFound,
+  private List<ColumnStatisticsObj> columnStatisticsObjForPartitions(final String dbName,
+    final String tableName, final List<String> partNames, List<String> colNames, long partsFound,
+    final boolean useDensityFunctionForNDVEstimation) throws MetaException {
+    final boolean areAllPartsFound = (partsFound == partNames.size());
+    return runBatched(colNames, new Batchable<String, ColumnStatisticsObj>() {
+      public List<ColumnStatisticsObj> run(final List<String> inputColNames) throws MetaException {
+        return runBatched(partNames, new Batchable<String, ColumnStatisticsObj>() {
+          public List<ColumnStatisticsObj> run(List<String> inputPartNames) throws MetaException {
+            return columnStatisticsObjForPartitionsBatch(dbName, tableName, inputPartNames,
+                inputColNames, areAllPartsFound, useDensityFunctionForNDVEstimation);
+          }
+        });
+      }
+    });
+  }
+
+  /** Should be called with the list short enough to not trip up Oracle/etc. */
+  private List<ColumnStatisticsObj> columnStatisticsObjForPartitionsBatch(String dbName,
+      String tableName, List<String> partNames, List<String> colNames, boolean areAllPartsFound,
       boolean useDensityFunctionForNDVEstimation) throws MetaException {
     // TODO: all the extrapolation logic should be moved out of this class,
     // only mechanical data retrieval should remain here.
@@ -1315,7 +1347,7 @@
private long partsFoundForPartitions(String dbName, String tableName,
     ForwardQueryResult fqr = null;
     // Check if the status of all the columns of all the partitions exists
     // Extrapolation is not needed.
-    if (partsFound == partNames.size()) {
+    if (areAllPartsFound) {
       queryText = commonPrefix + " and \"COLUMN_NAME\" in (" + makeParams(colNames.size()) + ")"
           + " and \"PARTITION_NAME\" in (" + makeParams(partNames.size()) + ")"
           + " group by \"COLUMN_NAME\", \"COLUMN_TYPE\"";
@@ -1408,15 +1440,10 @@
private long partsFoundForPartitions(String dbName, String tableName,
         // get sum for all columns to reduce the number of queries
         Map<String, Map<Integer, Object>> sumMap = new HashMap<String, Map<Integer, Object>>();
         queryText = "select \"COLUMN_NAME\", sum(\"NUM_NULLS\"), sum(\"NUM_TRUES\"), sum(\"NUM_FALSES\"), sum(\"NUM_DISTINCTS\")"
-            + " from \"PART_COL_STATS\""
-            + " where \"DB_NAME\" = ? and \"TABLE_NAME\" = ? "
-            + " and \"COLUMN_NAME\" in ("
-            + makeParams(extraColumnNameTypeParts.size())
-            + ")"
-            + " and \"PARTITION_NAME\" in ("
-            + makeParams(partNames.size())
-            + ")"
-            + " group by \"COLUMN_NAME\"";
+            + " from \"PART_COL_STATS\" where \"DB_NAME\" = ? and \"TABLE_NAME\" = ? "
+            + " and \"COLUMN_NAME\" in (" + makeParams(extraColumnNameTypeParts.size())
+            + ") and \"PARTITION_NAME\" in (" + makeParams(partNames.size())
+            + ") group by \"COLUMN_NAME\"";
         start = doTrace ? System.nanoTime() : 0;
         query = pm.newQuery("javax.jdo.query.SQL", queryText);
         List<String> extraColumnNames = new ArrayList<String>();
@@ -1517,8 +1544,7 @@
private long partsFoundForPartitions(String dbName, String tableName,
                     indexMap);
               }
             } else {
-              // if the aggregation type is avg, we use the average on the
-              // existing ones.
+              // if the aggregation type is avg, we use the average on the existing ones.
               queryText = "select "
                   + "avg((\"LONG_HIGH_VALUE\"-\"LONG_LOW_VALUE\")/cast(\"NUM_DISTINCTS\" as decimal)),"
                   + "avg((\"DOUBLE_HIGH_VALUE\"-\"DOUBLE_LOW_VALUE\")/\"NUM_DISTINCTS\"),"
@@ -1591,27 +1617,43 @@
private ColumnStatisticsObj prepareCSObjWithAdjustedNDV(Object[] row, int i,
     return params;
   }
 
-  public List<ColumnStatistics> getPartitionStats(String dbName, String tableName,
-      List<String> partNames, List<String> colNames) throws MetaException {
+  public List<ColumnStatistics> getPartitionStats(final String dbName, final String tableName,
+      final List<String> partNames, List<String> colNames) throws MetaException {
     if (colNames.isEmpty() || partNames.isEmpty()) {
       return Lists.newArrayList();
     }
-    boolean doTrace = LOG.isDebugEnabled();
-    long start = doTrace ? System.nanoTime() : 0;
-    String queryText = "select \"PARTITION_NAME\", " + STATS_COLLIST + " from \"PART_COL_STATS\""
-      + " where \"DB_NAME\" = ? and \"TABLE_NAME\" = ? and \"COLUMN_NAME\" in ("
-      + makeParams(colNames.size()) + ") AND \"PARTITION_NAME\" in ("
-      + makeParams(partNames.size()) + ") order by \"PARTITION_NAME\"";
+    final boolean doTrace = LOG.isDebugEnabled();
+    final String queryText0 = "select \"PARTITION_NAME\", " + STATS_COLLIST + " from "
+      + " \"PART_COL_STATS\" where \"DB_NAME\" = ? and \"TABLE_NAME\" = ? and \"COLUMN_NAME\""
+      + "  in (%1$s) AND \"PARTITION_NAME\" in (%2$s) order by \"PARTITION_NAME\"";
+    Batchable<String, Object[]> b = new Batchable<String, Object[]>() {
+      public List<Object[]> run(final List<String> inputColNames) throws MetaException {
+        Batchable<String, Object[]> b2 = new Batchable<String, Object[]>() {
+          public List<Object[]> run(List<String> inputPartNames) throws MetaException {
+            String queryText = String.format(queryText0,
+                makeParams(inputColNames.size()), makeParams(inputPartNames.size()));
+            long start = doTrace ? System.nanoTime() : 0;
+            Query query = pm.newQuery("javax.jdo.query.SQL", queryText);
+            Object qResult = executeWithArray(query, prepareParams(
+                dbName, tableName, inputPartNames, inputColNames), queryText);
+            timingTrace(doTrace, queryText0, start, (doTrace ? System.nanoTime() : 0));
+            if (qResult == null) {
+              query.closeAll();
+              return Lists.newArrayList();
+            }
+            addQueryAfterUse(query);
+            return ensureList(qResult);
+          }
+        };
+        try {
+          return runBatched(partNames, b2);
+        } finally {
+          addQueryAfterUse(b2);
+        }
+      }
+    };
+    List<Object[]> list = runBatched(colNames, b);
 
-    Query query = pm.newQuery("javax.jdo.query.SQL", queryText);
-    Object qResult = executeWithArray(query, prepareParams(
-        dbName, tableName, partNames, colNames), queryText);
-    long queryTime = doTrace ? System.nanoTime() : 0;
-    if (qResult == null) {
-      query.closeAll();
-      return Lists.newArrayList();
-    }
-    List<Object[]> list = ensureList(qResult);
     List<ColumnStatistics> result = new ArrayList<ColumnStatistics>(
         Math.min(list.size(), partNames.size()));
     String lastPartName = null;
@@ -1630,9 +1672,7 @@
private ColumnStatisticsObj prepareCSObjWithAdjustedNDV(Object[] row, int i,
       from = i;
       Deadline.checkTimeout();
     }
-
-    timingTrace(doTrace, queryText, start, queryTime);
-    query.closeAll();
+    b.closeAllQueries();
     return result;
   }
 
@@ -1710,4 +1750,48 @@
public void prepareTxn() throws MetaException {
       throw new MetaException("Error setting ansi quotes: " + sqlEx.getMessage());
     }
   }
+
+
+  private static abstract class Batchable<I, R> {
+    private List<Query> queries = null;
+    public abstract List<R> run(List<I> input) throws MetaException;
+    public void addQueryAfterUse(Query query) {
+      if (queries == null) {
+        queries = new ArrayList<Query>(1);
+      }
+      queries.add(query);
+    }
+    protected void addQueryAfterUse(Batchable<?, ?> b) {
+      if (b.queries == null) return;
+      if (queries == null) {
+        queries = new ArrayList<Query>(1);
+      }
+      queries.addAll(b.queries);
+    }
+    public void closeAllQueries() {
+      for (Query q : (queries == null) ? new ArrayList<Query>() : queries) {
+        try {
+          q.closeAll();
+        } catch (Throwable t) {
+          LOG.error("Failed to close a query", t);
+        }
+      }
+    }
+  }
+
+  private <I,R> List<R> runBatched(List<I> input, Batchable<I, R> runnable) throws MetaException {
+    if (batchSize == NO_BATCHING || batchSize >= input.size()) {
+      return runnable.run(input);
+    }
+    List<R> result = new ArrayList<R>(input.size());
+    for (int fromIndex = 0, toIndex = 0; toIndex < input.size(); fromIndex = toIndex) {
+      toIndex = Math.min(fromIndex + batchSize, input.size());
+      List<I> batchedInput = input.subList(fromIndex, toIndex);
+      List<R> batchedOutput = runnable.run(batchedInput);
+      if (batchedOutput != null) {
+        result.addAll(batchedOutput);
+      }
+    }
+    return result;
+  }
 }
